"""
Parsing module for models.py files. Extracts information in a more reliable
way than inspect + regexes.
Now only used as a fallback when introspection and the South custom hook both fail.
"""

import re
import inspect
import parser
import symbol
import token
import keyword
import datetime

from django.db import models
from django.contrib.contenttypes import generic
from django.utils.datastructures import SortedDict
from django.core.exceptions import ImproperlyConfigured


def name_that_thing(thing):
    """
    Gives the readable name for a symbol/token number.

    The symbol module is searched before the token module; the first
    attribute whose value equals 'thing' wins and is returned as
    "symbol.<name>" or "token.<name>". If nothing matches, str(thing) is
    returned instead.
    """
    lookups = (
        ("symbol.%s", symbol),
        ("token.%s", token),
    )
    for template, module in lookups:
        for attr in dir(module):
            if getattr(module, attr) == thing:
                return template % attr
    return str(thing)


def thing_that_name(name):
    """
    Looks up the integer value for a symbol/token name.

    The symbol module is consulted first, then the token module.
    Raises ValueError if neither module lists the name.
    """
    for module in (symbol, token):
        if name in dir(module):
            return getattr(module, name)
    raise ValueError("Cannot convert '%s'" % name)


def prettyprint(tree, indent=0, omit_singles=False):
    """
    Prettyprints the tree, with symbol/token names. For debugging.

    tree is a nested tuple (or a leaf value); indent is the current nesting
    depth, used to pad integer leaves and closing parentheses.
    If omit_singles is true, any two-item tuple (a node with a single child)
    is replaced by its child, at every depth of the tree.
    Returns the formatted string.
    """
    if omit_singles and isinstance(tree, tuple) and len(tree) == 2:
        return prettyprint(tree[1], indent, omit_singles)
    if isinstance(tree, tuple):
        # Pass omit_singles down; it used to be dropped here, so only the
        # outermost node was ever collapsed.
        children = "".join(
            [prettyprint(x, indent+1, omit_singles) for x in tree]
        )
        return " (\n%s\n" % children + (" " * indent) + ")"
    elif isinstance(tree, int):
        return (" " * indent) + name_that_thing(tree)
    else:
        return " " + repr(tree)


def isclass(obj):
    "Reports whether obj is a class, i.e. whether its type derives from 'type'."
    kind = type(obj)
    return issubclass(kind, type)


def aliased_models(module):
    """
    Builds a dict of alias name -> real model name for a models module.

    Every entry in the module's namespace that is a Model subclass (other
    than Model itself) and is bound under a name different from its
    _meta.object_name counts as an alias (e.g. import Foo as Bar). Bug #134.
    """
    found = {}
    for bound_name, candidate in module.__dict__.items():
        if not isclass(candidate):
            continue
        if not issubclass(candidate, models.Model) or candidate is models.Model:
            continue
        # An alias is any binding whose name disagrees with the model's own
        real_name = candidate._meta.object_name
        if bound_name != real_name:
            found[bound_name] = real_name
    return found
    


class STTree(object):
    """
    A syntax tree wrapper class.

    Wraps a tree in nested-tuple form: the first item of each tuple is taken
    as the node's symbol/token number and the remaining items as its
    children (or, for tokens, the token text).
    """

    def __init__(self, tree):
        self.tree = tree

    def __eq__(self, other):
        # Only STTrees can be equal to an STTree; deferring with
        # NotImplemented avoids an AttributeError on other.tree.
        if not isinstance(other, STTree):
            return NotImplemented
        return other.tree == self.tree

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so keep the two in step.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        return hash(self.tree)

    @property
    def root(self):
        "The symbol/token number of the top node."
        return self.tree[0]

    @property
    def value(self):
        "The raw wrapped tree."
        return self.tree

    def walk(self, recursive=True):
        """
        Yields (symbol, subtree) for the entire subtree.
        Comes out with node 1, node 1's children, node 2, etc.
        The top node itself is never yielded. With recursive=False only the
        top node's direct children are yielded.
        """
        stack = [self.tree]
        done_outer = False
        while stack:
            atree = stack.pop()
            # Non-tuple items (node numbers, token text) are skipped
            if isinstance(atree, tuple):
                if done_outer:
                    yield atree[0], STTree(atree)
                if recursive or not done_outer:
                    # Reversed so that popping visits children left-to-right
                    for bit in reversed(atree[1:]):
                        stack.append(bit)
                    done_outer = True

    def flatten(self):
        """
        Returns a list of the tokens in the tree, in order.
        Tokens listed in token_map appear as their token number; NAME,
        STRING and NUMBER tokens appear as their raw (number, text, ...)
        tuple. Everything else is left out.
        """
        bits = []
        for sym, subtree in self.walk():
            if sym in token_map:
                bits.append(sym)
            elif sym in (token.NAME, token.STRING, token.NUMBER):
                bits.append(subtree.value)
        return bits

    def reform(self):
        "Returns how the tree's input probably looked, as a string."
        return reform(self.flatten())

    def findAllType(self, ntype, recursive=True):
        "Yields all nodes with the given type in the tree."
        # 'sym', not 'symbol', so the symbol module is not shadowed
        for sym, subtree in self.walk(recursive=recursive):
            if sym == ntype:
                yield subtree

    def find(self, selector):
        """
        Searches the syntax tree with a CSS-like selector syntax.
        You can use things like 'suite simple_stmt', 'suite, simple_stmt'
        or 'suite > simple_stmt'. A leading '^' stands for this tree itself.
        Returns a list of STTrees; not guaranteed to return in order.
        """
        # Split up the overall parts
        patterns = [x.strip() for x in selector.split(",")]
        results = []
        for pattern in patterns:
            # Split up the parts; '>' separators are kept, whitespace-only
            # separators come through as None
            parts = re.split(r'(?:[\s]|(>))+', pattern)
            # Take the first part, use it for results
            if parts[0] == "^":
                subresults = [self]
            else:
                subresults = list(self.findAllType(thing_that_name(parts[0])))
            recursive = True
            # For each remaining part, narrow down the current results
            for part in parts[1:]:
                if not subresults:
                    break
                if part == ">":
                    # The next part must match direct children only
                    recursive = False
                elif not part:
                    pass
                else:
                    thing = thing_that_name(part)
                    narrowed = []
                    for tree in subresults:
                        narrowed.extend(tree.findAllType(thing, recursive))
                    subresults = narrowed
                    recursive = True
            results.extend(subresults)
        return results

    def __str__(self):
        return prettyprint(self.tree)
    __repr__ = __str__
    

def get_model_tree(model):
    """
    Parses the file that defines 'model' and returns the STTree for the
    compound statement holding its class definition.

    The class is matched by name, ignoring case. Returns None if no class
    definition with that name is found in the file.
    """
    # Get the source of the model's file, making sure the handle is closed
    source_file = open(inspect.getsourcefile(model))
    try:
        raw = source_file.read()
    finally:
        source_file.close()
    # Normalise line endings and make sure the source ends with a newline
    source = raw.replace("\r\n", "\n").replace("\r", "\n") + "\n"
    tree = STTree(parser.suite(source).totuple())
    # Now, we have to find it
    for poss in tree.find("compound_stmt"):
        # value[1] is the statement under compound_stmt; for a classdef,
        # item 2 of that is the NAME token carrying the class name
        if poss.value[1][0] == symbol.classdef and \
           poss.value[1][2][1].lower() == model.__name__.lower():
            # This is the tree
            return poss
    return None


# Maps operator/delimiter token numbers to the source text they stand for.
token_map = {
    token.DOT: ".",
    token.LPAR: "(",
    token.RPAR: ")",
    token.EQUAL: "=",
    token.EQEQUAL: "==",
    token.COMMA: ",",
    token.LSQB: "[",
    token.RSQB: "]",
    token.AMPER: "&",
    token.CIRCUMFLEX: "^",
    token.CIRCUMFLEXEQUAL: "^=",
    token.COLON: ":",
    token.DOUBLESLASH: "//",
    token.DOUBLESLASHEQUAL: "//=",
    token.DOUBLESTAR: "**",
    # Was keyed on DOUBLESLASHEQUAL a second time, which clobbered "//="
    token.DOUBLESTAREQUAL: "**=",
    token.GREATER: ">",
    token.LESS: "<",
    token.GREATEREQUAL: ">=",
    token.LESSEQUAL: "<=",
    token.LBRACE: "{",
    token.RBRACE: "}",
    token.SEMI: ";",
    token.PLUS: "+",
    token.MINUS: "-",
    token.STAR: "*",
    token.SLASH: "/",
    token.VBAR: "|",
    token.PERCENT: "%",
    token.TILDE: "~",
    token.AT: "@",
    token.NOTEQUAL: "!=",
    token.LEFTSHIFT: "<<",
    token.RIGHTSHIFT: ">>",
    token.LEFTSHIFTEQUAL: "<<=",
    token.RIGHTSHIFTEQUAL: ">>=",
    token.PLUSEQUAL: "+=",
    token.MINEQUAL: "-=",
    token.STAREQUAL: "*=",
    token.SLASHEQUAL: "/=",
    token.VBAREQUAL: "|=",
    token.PERCENTEQUAL: "%=",
    token.AMPEREQUAL: "&=",
}
# The backquote token only exists on interpreters that still have the
# `x` repr syntax, so add it only where the token module defines it.
if hasattr(token, "BACKQUOTE"):
    token_map[token.BACKQUOTE] = "`"


def reform(bits):
    """
    Rebuilds source text from 'bits', a sequence of token numbers and
    (token number, text, ...) tuples.

    Token numbers found in token_map are emitted as their mapped text.
    NAME/STRING/NUMBER tuples are emitted as their text, with keywords
    padded by a space on each side. Anything else is dropped.
    """
    pieces = []
    for bit in bits:
        if bit in token_map:
            pieces.append(token_map[bit])
            continue
        if bit[0] not in [token.NAME, token.STRING, token.NUMBER]:
            continue
        text = bit[1]
        if keyword.iskeyword(text):
            pieces.append(" %s " % text)
        elif text not in symbol.sym_name:
            pieces.append(text)
    return "".join(pieces)


def parse_arguments(argstr):
    """
    Splits a string of call arguments into positional and keyword parts.

    Returns (args, kwds): a list of positional arguments and a dict of
    keyword arguments. Every entry is python source text, except the
    dict keys, which are the keyword names.
    """
    # Parse the argument string into a tree
    tree = STTree(parser.suite(argstr).totuple())

    positional = []
    keywords = {}
    pending_key = None

    # BTW: A testlist is to the left or right of an =.
    groups = tree.find("testlist")
    final_group = len(groups) - 1
    for group_no, group in enumerate(groups):
        members = list(group.walk(recursive=False))
        final_member = len(members) - 1
        for member_no, (sym, node) in enumerate(members):
            if sym != symbol.test:
                continue
            text = node.reform()
            if pending_key:
                # A keyword name is waiting; this is its value
                keywords[pending_key] = text
                pending_key = None
            elif member_no == final_member and group_no != final_group:
                # Last item in a group must be a keyword, unless it's last overall
                pending_key = text
            else:
                positional.append(text)
    return positional, keywords


def extract_field(tree):
    """
    Tries to read a flattened statement tree as 'name = Class(arguments)'.

    Returns (name, clsname, args, kwargs) on success, where args and kwargs
    are what parse_arguments() gives for the text between the first '(' and
    the last ')'. Returns None if the statement does not have that shape or
    its arguments cannot be parsed.
    """
    tokens = tree.flatten()
    # Must be at least "<something> =" to be worth looking at
    if len(tokens) < 2:
        return
    if tokens[1] != token.EQUAL:
        return
    field_name = tokens[0][1]
    rhs = tokens[2:]
    # Everything before the first LPAR is the class being called
    if token.LPAR not in rhs:
        return
    open_at = rhs.index(token.LPAR)
    clsname = reform(rhs[:open_at])
    # The arguments run up to the last RPAR; find it by searching a
    # reversed copy and converting the position back
    close_at = (len(rhs) - 1) - rhs[::-1].index(token.RPAR)
    arg_tokens = rhs[open_at+1:close_at]
    # Now, extract the arguments as a list and dict
    try:
        args, kwargs = parse_arguments(reform(arg_tokens))
    except SyntaxError:
        return
    return field_name, clsname, args, kwargs
    


def get_model_fields(model, m2m=False):
    """
    Given a model class, will return the dict of name: field_constructor
    mappings.

    The result is a SortedDict keyed by field name, covering the model's
    local fields (plus its local many-to-many fields if m2m is true). Each
    value is the first of these that applies: the (class, args, kwargs)
    definition parsed from the model's source; the result of the field's
    south_field_triple() or south_field_definition(); a definition
    inherited from a base model; a built-in default for *_ptr and id
    fields; or None if nothing could be found.

    Returns None if the model's class definition cannot be located.
    """
    tree = get_model_tree(model)
    if tree is None:
        return None
    possible_field_defs = tree.find("^ > classdef > suite > stmt > simple_stmt > small_stmt > expr_stmt")
    field_defs = {}

    # Get aliases, ready for alias fixing (#134)
    try:
        aliases = aliased_models(models.get_app(model._meta.app_label))
    except ImproperlyConfigured:
        aliases = {}

    # Go through all the found defns, and try to parse them
    for pfd in possible_field_defs:
        field = extract_field(pfd)
        if field:
            field_defs[field[0]] = field[1:]

    inherited_fields = {}
    # Go through all bases (that are themselves models, but not Model)
    for base in model.__bases__:
        if base != models.Model and issubclass(base, models.Model):
            base_fields = get_model_fields(base, m2m)
            # The recursive call gives None when the base's class definition
            # could not be found; update(None) would raise TypeError.
            if base_fields:
                inherited_fields.update(base_fields)

    # Now, go through all the fields and try to get their definition
    source = model._meta.local_fields[:]
    if m2m:
        source += model._meta.local_many_to_many
    fields = SortedDict()
    for field in source:
        # Get its name
        fieldname = field.name
        if isinstance(field, (models.related.RelatedObject, generic.GenericRel)):
            continue
        # Now, try to get the defn
        if fieldname in field_defs:
            fields[fieldname] = field_defs[fieldname]
        # Try the South definition workaround?
        elif hasattr(field, 'south_field_triple'):
            fields[fieldname] = field.south_field_triple()
        elif hasattr(field, 'south_field_definition'):
            print("Your custom field %s provides the outdated south_field_definition method.\nPlease consider implementing south_field_triple too; it's more reliably evaluated." % field)
            fields[fieldname] = field.south_field_definition()
        # Try a parent?
        elif fieldname in inherited_fields:
            fields[fieldname] = inherited_fields[fieldname]
        # Is it a _ptr?
        elif fieldname.endswith("_ptr"):
            fields[fieldname] = ("models.OneToOneField", ["orm['%s.%s']" % (field.rel.to._meta.app_label, field.rel.to._meta.object_name)], {})
        # Try a default for 'id'.
        elif fieldname == "id":
            fields[fieldname] = ("models.AutoField", [], {"primary_key": "True"})
        else:
            fields[fieldname] = None

    # Now, try seeing if we can resolve the values of defaults, and fix aliases.
    for field, defn in fields.items():

        if not isinstance(defn, (list, tuple)):
            continue # We don't have a defn for this one, or it's a string

        # Fix aliases if we can (#134)
        for i, arg in enumerate(defn[1]):
            if arg in aliases:
                defn[1][i] = aliases[arg]

        # Fix defaults if we can
        for arg, val in defn[2].items():
            if arg in ['default']:
                try:
                    # Evaluate it in a close-to-real fake model context.
                    # NOTE: eval() executes whatever source text is in val;
                    # only safe for definitions that come from trusted code.
                    real_val = eval(val, __import__(model.__module__, {}, {}, ['']).__dict__, model.__dict__)
                # If we can't resolve it, stick it in verbatim. Only
                # Exception is caught, so Ctrl-C and SystemExit still
                # get through.
                except Exception:
                    pass # TODO: Raise nice error here?
                # Hm, OK, we got a value. Callables are not frozen (see #132, #135)
                else:
                    if callable(real_val):
                        # HACK
                        # However, if it's datetime.now, etc., that's special
                        for datetime_key in datetime.datetime.__dict__.keys():
                            # No, you can't use __dict__.values. It's different.
                            dtm = getattr(datetime.datetime, datetime_key)
                            if real_val == dtm:
                                if not val.startswith("datetime.datetime"):
                                    defn[2][arg] = "datetime." + val
                                break
                    else:
                        defn[2][arg] = repr(real_val)

    return fields
